package io.zbus.data.impl;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

import javax.sql.DataSource;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.PreparedStatementCreator;
import org.springframework.jdbc.datasource.DataSourceTransactionManager;
import org.springframework.jdbc.support.GeneratedKeyHolder;
import org.springframework.jdbc.support.KeyHolder;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.support.TransactionCallbackWithoutResult;
import org.springframework.transaction.support.TransactionTemplate;

import com.alibaba.fastjson.JSONObject;

import io.zbus.data.api.Db;
import io.zbus.data.impl.TableKit.TableColumn;
import io.zbus.data.impl.meta.MetaData;
import io.zbus.data.impl.meta.MetaReader;
import io.zbus.data.impl.meta.Table;
import io.zbus.data.kit.JsonKit;
import io.zbus.rpc.annotation.Route; 

public class SpringDb implements Db {  
	private static final Logger logger = LoggerFactory.getLogger(SpringDb.class);
	 
	private JdbcTemplate jdbc; 
	private TransactionTemplate tx;  
	private MetaData meta;
	private MetaReader reader;
	
	public SpringDb(DataSource dataSource) {
		jdbc = new JdbcTemplate(dataSource);
		DataSourceTransactionManager txManager = new DataSourceTransactionManager(dataSource);
		tx = new TransactionTemplate(txManager);
		
		reader = new MetaReader(dataSource);  
	}
	
	private void initMeta() {
		if(meta == null) { //if init failed, try this time
			synchronized (this) {
				try {
					meta = reader.reflect(); 
				} catch (Exception e) {
					logger.error(e.getMessage(), e);
				}
			} 
		}
	}
	
	private Table getTableInfo(String table) { 
		if(meta == null) { //if init failed, try this time
			initMeta();
		}
		Table t = meta.tables.get(table);
		if(t == null) {
			throw new IllegalArgumentException("Table("+table + ") Not Found");
		}
		return t;
	}
	
	private String r(String word) {
		return this.meta.dialect.resolve(word);
	}   
	
	//==================queryList===============
	@Override
	public List<Map<String, Object>> queryList(String sql, Object[] args) {  
		return jdbc.queryForList(sql, args);
	}
	
	@Override
	public List<Map<String, Object>> queryList(String sql) { 
		return jdbc.queryForList(sql, new Object[0]);
	}
	
	@Override
	public List<Map<String, Object>> queryList(String sql, Object jsonWhere) {
		return queryList(sql, new Object[0], jsonWhere);
	}
	
	@Override
	public List<Map<String, Object>> queryList(String sql, Object[] args, Object jsonWhere) { 
		List<Object> paramList = new ArrayList<>();
		sql = buildSql(sql, args, jsonWhere, paramList); 
		return queryList(sql, paramList.toArray());
	}
	
	private String buildSql(String sql, Object[] args, Object jsonWhere, List<Object> paramList) {
		JSONObject json = JsonKit.convert(jsonWhere, JSONObject.class);
		JsonConditionParser parser = new JsonConditionParser(meta, sql); 
		for(Object arg : args) {
			paramList.add(arg);
		}
		String exp = parser.parse(json, paramList);
		if(!TableKit.whereExists(sql)) {
			sql += " WHERE " + exp;
		} else {
			sql += " AND (" + exp + ")";
		}  
		return sql;
	}
	
	@Override
	public <T> List<T> queryList(Class<T> type, String sql, Object[] args) { 
		return JsonKit.convertList(queryList(sql, args), type);
	}
	
	@Override
	public <T> List<T> queryList(Class<T> type, String sql) {
		return queryList(type, sql, new Object[0]);
	}
	 
	
	//==================queryPage===============
	@Override
	public List<Map<String, Object>> queryPage(String sql, int page, int size, Object[] args) { 
		int offset = page*size;
		int limit = size;
		initMeta(); //if meta not found
		return queryList(meta.dialect.paging(sql, offset, limit), args); 
	}
	
	@Override
	public List<Map<String, Object>> queryPage(String sql, int page, int size) {
		return queryPage(sql, page, size, new Object[0]);
	} 
	
	@Override
	public List<Map<String, Object>> queryPage(String sql, int page, int size, Object[] args, Object jsonWhere) {
		List<Object> paramList = new ArrayList<>();
		sql = buildSql(sql, args, jsonWhere, paramList); 
		return queryPage(sql, page, size, paramList.toArray()); 
	}
	
	@Override
	public List<Map<String, Object>> queryPage(String sql, int page, int size, Object jsonWhere) {
		return queryPage(sql, page, size, new Object[0], jsonWhere);
	}
	
	@Override
	public <T> List<T> queryPage(Class<T> type, String sql, int page, int size, Object[] args) {
		return JsonKit.convertList(queryPage(sql, page, size), type);
	}
	
	@Override
	public <T> List<T> queryPage(Class<T> type, String sql, int page, int size) {
		return queryPage(type, sql, page, size, new Object[0]);
	}
	
	
	//==================queryMap===============
	@Override
	public Map<String, Object> queryOne(String sql, Object[] args) { 
		return jdbc.queryForMap(sql, args);
	}
	
	@Override
	public Map<String, Object> queryOne(String sql) { 
		return queryOne(sql, new Object[0]);
	}  
	
	@Override
	public Map<String, Object> queryOne(String sql, Object[] args, Object jsonWhere) {
		List<Object> paramList = new ArrayList<>();
		sql = buildSql(sql, args, jsonWhere, paramList); 
		return queryOne(sql, paramList.toArray()); 
	}
	
	@Override
	public Map<String, Object> queryOne(String sql, Object jsonWhere) {
		return queryOne(sql, new Object[0], jsonWhere);
	}
	
	@Override
	public <T> T queryOne(Class<T> type, String sql) {
		return JsonKit.convert(queryOne(sql), type);
	}
	
	@Override
	public <T> T queryOne(Class<T> type, String sql, Object[] args) {
		return queryOne(type, sql, new Object[0]);
	}
	
	//==================queryObject===============
	@Override
	public Object queryObject(String sql, Object[] args) {
		return jdbc.queryForObject(sql, args, Object.class); 
	} 
	
	@Override
	public Object queryObject(String sql) {
		return queryObject(sql, new Object[0]); 
	} 
	
	@Override
	public Object queryObject(String sql, Object[] args, Object jsonWhere) {
		List<Object> paramList = new ArrayList<>();
		sql = buildSql(sql, args, jsonWhere, paramList); 
		return queryObject(sql, paramList.toArray());  
	}
	
	@Override
	public Object queryObject(String sql, Object jsonWhere) {
		return queryObject(sql, new Object[0], jsonWhere);
	}
	
	@Override
	public <T> T queryObject(Class<T> type, String sql, Object[] args) { 
		return JsonKit.convert(queryObject(sql, args), type);
	}
	
	@Override
	public <T> T queryObject(Class<T> type, String sql) { 
		return queryObject(type, sql, new Object[0]);
	}
	
	
	//==================select===============

	@Override
	public List<Map<String, Object>> select(String table, Integer page, Integer size, Object jsonWhere) { 
		Table t = getTableInfo(table);
		List<Object> valueList = new ArrayList<>();
		String sql = buildSql(t, jsonWhere, valueList); 
		
		if(t.primaryKeys.keySet().isEmpty()) { 
			sql += " ORDER BY " + r(t.columnsInSeq.get(0).columnName); //first column as order by
		} else {   
			List<String> pks = new ArrayList<>(); 
			for(String pk : t.primaryKeys.keySet()) pks.add(r(pk));
			sql += " ORDER BY " + String.join(",", pks);
		} 
		
		if(page == null) page = 0;
		if(size == null) size = 100;
		return queryPage(sql, page, size, valueList.toArray());
	}
	
	private String buildSql(Table t, Object jsonWhere, List<Object> paramList) {
		String sql = String.format("SELECT * FROM %s", r(t.tableName) ); 
		String where = buildWhere(t, jsonWhere, paramList);
		if(where != null) {
			sql += " WHERE " + where;
		}
		return sql;
	}
	
	private String buildWhere(Table t, Object jsonWhere, List<Object> paramList) {
		JSONObject json = JsonKit.convert(jsonWhere, JSONObject.class);
		JsonConditionParser parser = new JsonConditionParser(meta); 
		parser.setTable(t.tableName, t.tableName); 
		return parser.parse(json, paramList); 
	}
	
	@Override
	public List<Map<String, Object>> select(String table) {
		return select(table, null);
	}
	@Override
	public List<Map<String, Object>> select(String table, Integer page, Integer size) {
		return select(table, page, size, null);
	}
	
	@Override
	public List<Map<String, Object>> select(String table, Object jsonWhere) {
		return select(table, 0, 100, jsonWhere);
	}
	
	@Override
	public <T> List<T> select(Class<T> type, String table, Integer page, Integer size, Object jsonWhere) { 
		return JsonKit.convertList(select(table, page, size, jsonWhere), type);
	}
	
	@Override
	public <T> List<T> select(Class<T> type, String table, Integer page, Integer size) { 
		return select(type, table, page, size, null);
	}
	
	@Override
	public <T> List<T> select(Class<T> type, String table, Object jsonWhere) {
		return select(type, table, null, null, jsonWhere);
	}
	
	@Override
	public <T> List<T> select(Class<T> type, String table) {
		return select(type, table, null, null, null);
	}
	
	//==================execute===============
	@Override
	public int execute(String sql, Object[] args) {  
		return jdbc.update(sql, args); 
	}
	
	@Override
	public int execute(String sql) {  
		return execute(sql, new Object[0]); 
	}
	
	@Override
	public int execute(String sql, Object[] args, Object jsonWhere) {
		List<Object> paramList = new ArrayList<>();
		sql = buildSql(sql, args, jsonWhere, paramList); 
		return execute(sql, paramList.toArray());
	}
	
	@Override
	public int execute(String sql, Object jsonWhere) { 
		return execute(sql, new Object[0], jsonWhere);
	}
	
	
	//==================insert===============
	@SuppressWarnings("unchecked")
	@Override
	public <T> int insert(String table, T record) { 
		Map<String, Object> r = JsonKit.convert(record, Map.class); 
		
		Table t = getTableInfo(table);
		final SqlExecute p = buildInsert(t, r);
		KeyHolder holder = new GeneratedKeyHolder();
		
		int rc = jdbc.update(new PreparedStatementCreator() {
			@Override
			public PreparedStatement createPreparedStatement(Connection connection) throws SQLException {
				PreparedStatement ps = connection.prepareStatement(p.sql, Statement.RETURN_GENERATED_KEYS);
				for(int i=0; i< p.args.length; i++) {
					ps.setObject(i+1, p.args[i]);
				}
				return ps;
			}
		}, holder);  
		
		Object autoKey = holder.getKey();
		if(t.autoColumn != null && autoKey != null) {
			r.put(t.autoColumn.columnName, autoKey);
		} 
		return rc; 
	}
	
	private SqlExecute buildInsert(Table t, Map<String, Object> record) { 
		List<String> colList = new ArrayList<>();
		List<String> qmarkList = new ArrayList<>(); 
		List<Object> valList = new ArrayList<>(); 
		
		for(Entry<String, Object> e : record.entrySet()){
			String colName = e.getKey();
			Object colVal = e.getValue();
			if(!t.columns.containsKey(colName)) {
				String msg = String.format("Key=(%s) is not a column in Table(%s):\n%s", 
						colName,  t.tableName, JsonKit.toJSONPrintString(record)); 
				logger.debug(msg);
				continue;
			}
			
			colList.add(r(colName));
			qmarkList.add("?");
			valList.add(colVal); 
		}  
		
		String sql = String.format("INSERT INTO %s(%s) VALUES(%s)", 
				r(t.tableName), 
				String.join(",", colList),
				String.join(",", qmarkList)
		); 
		
		SqlExecute pair = new SqlExecute();
		pair.sql = sql;
		pair.args = valList.toArray();
		
		return pair;
	}
	
	

	@Override
	public int save(String table, Map<String, Object> record) {
		Table t = getTableInfo(table);
		Set<String> keySet = t.primaryKeys.keySet(); 
		boolean primaryKeyPopulatd = true;
		for(String key : keySet) {
			if(!record.containsKey(key)) {
				primaryKeyPopulatd = false;
				break;
			}
		}
		if(primaryKeyPopulatd) {
			return update(table, record);
		} else {
			return insert(table, record);
		} 
	}
	
	@SuppressWarnings("unchecked")
	@Override
	public <T> int save(String table, T record) {
		Map<String, Object> r = JsonKit.convert(record, Map.class);
		return save(table, r);
	}
	
	//==================Batch operations===============
	@SuppressWarnings("unchecked")
	@Override
	public <T> List<Integer> batchInsert(String table, List<T> record) {
		if(record.isEmpty()) return new ArrayList<>();
		Table t = getTableInfo(table);
		
		final List<Object[]> batchArgs = new ArrayList<>();
		String sql = null;
		for(T r : record) {
			Map<String, Object> m = JsonKit.convert(r, Map.class);
			SqlExecute p = buildInsert(t,m);
			sql = p.sql; //TODO make sure all data same format
			batchArgs.add(p.args);
		} 
		return batchUpdate(sql, batchArgs);
	}
	 
	
	public List<Integer> batchUpdate(String sql, Object batchArgs){
		List<Object[]> args = JsonKit.convertList(batchArgs, Object[].class);
		return batchUpdate(sql, args);
	}
	
	public List<Integer> batchUpdate(String sql, List<Object[]>  batchArgs){ 
		int[] data = jdbc.batchUpdate(sql, batchArgs);
		List<Integer> res = new ArrayList<>();
		for(int d : data) res.add(d);
		return res;
	}  
	 
	@SuppressWarnings("unchecked")
	@Override
	public <T> int update(String table, T record, String where, Object[] args) {
		Table t = getTableInfo(table);
		List<String> setList = new ArrayList<>();   
		List<Object> setValList = new ArrayList<>();   
		
		Map<String, Object> m = JsonKit.convert(record, Map.class);
		for(Entry<String, Object> e : m.entrySet()){
			String colName = e.getKey();
			Object colVal = e.getValue();
			if(!t.columns.containsKey(colName)) {
				String msg = String.format("Key=(%s) is not a column in Table(%s):\n%s", 
						colName, t.tableName, JsonKit.toJSONPrintString(record));
				
				logger.debug(msg);
				continue;
			}   
			setList.add(r(colName) + "=?");  
			setValList.add(colVal);
		}   
		
		String sql = String.format("UPDATE %s SET %s WHERE %s", 
				r(t.tableName), 
				String.join(", ", setList),
				where
		); 
		
		for(Object arg : args) setValList.add(arg); 
		
		return jdbc.update(sql, setValList.toArray());
	} 
	 
	 
	@SuppressWarnings("unchecked")
	@Override
	public <T> int update(String table, T record) {
		Map<String, Object> r = JsonKit.convert(record, Map.class);
		return update(table, r);
	}
	 
	@Override
	public <T> int update(String table, T record, String where) { 
		return update(table, record, where, new Object[0]);
	} 
	
	@Override
	public <T> int update(String table, T record, Object jsonWhere) {
		List<Object> paramList = new ArrayList<>();
		Table t = getTableInfo(table);
		String where = buildWhere(t, jsonWhere, paramList);
		return update(table, record, where, paramList.toArray());
	}
	
	//==================remove===============
	@Override
	public int remove(String table, String where, Object[] args) {
		Table t = getTableInfo(table);
		String sql = String.format("DELETE FROM %s WHERE %s", 
				r(t.tableName),  
				where
		); 
		
		return jdbc.update(sql, args);
	}
	
	@Override
	public int remove(String table, String where) {
		return remove(table, where, new Object[0]);
	}
	 
	@Override
	public int remove(String table, Object jsonWhere) { 
		List<Object> paramList = new ArrayList<>();
		Table t = getTableInfo(table);
		String where = buildWhere(t, jsonWhere, paramList);
		return remove(table, where, paramList.toArray());
	}
	
	protected TableColumn validateTableColumn(String tableColumn) {
		TableColumn tc = TableKit.tableColumn(tableColumn);
		if(tc.table == null) {
			String error = String.format("Table is missing, format <table>.<column>: %s",  tableColumn);
			throw new IllegalArgumentException(error); 
		}
		Table t = getTableInfo(tc.table);
		if(t == null) {
			String error = String.format("Table(%s) Not Found: %s", tc.table, tableColumn);
			throw new IllegalArgumentException(error);
		}
		if(!t.columns.containsKey(tc.column)){
			String error = String.format("Column(%s) not in Table(%s): %s", tc.column, tc.table, tableColumn);
			throw new IllegalArgumentException(error);
		}
		return tc;
	}

	
	private Map<String, List<Map<String, Object>>> groupResultByKey(List<Map<String, Object>> data, String key){
		Map<String, List<Map<String, Object>>> res = new HashMap<>();
		for(Map<String, Object> r : data) {
			if(!r.containsKey(key)) continue;
			
			String groupKey = r.get(key).toString();
			List<Map<String, Object>> groupList = res.get(groupKey);
			if(groupList == null) {
				groupList = new ArrayList<>();
				res.put(groupKey, groupList);
			}
			groupList.add(r);
		}
		
		return res; 
	}
	
	//==================link operations===============  
	@Override
	public void linkQuery(Object data, Object query) {
		LinkQuery linkQuery = JsonKit.convert(query, LinkQuery.class); 
		linkQuery(data, linkQuery);
	} 
	 
	@SuppressWarnings("unchecked")
	@Override
	public void linkQuery(Object record, LinkQuery query) {
		if(record instanceof List) {
			linkQuery((List<Object>)record, query);
			return;
		}
		linkQuery(Arrays.asList(record), query);
	}
	
	public void linkQuery(List<Object> records, LinkQuery query) {
		//validateTableColumn(query.sqlCondKey); //SQL injection, sqlCondKey should not accept external input
		String groupKey = query.linkColumn; 
		TableColumn tc = TableKit.tableColumn(groupKey);
		groupKey = tc.column; //ignore table
		
		String[] bb = query.path.split("[.]");
		
		final String propertyName = bb[bb.length-1].trim(); 
		
		String[] propertyPaths = new String[bb.length-1];
		String[] keyPaths = new String[bb.length];
		for(int i=0;i<keyPaths.length-1;i++) {
			propertyPaths[i] = keyPaths[i] = bb[i].trim();
		}
		keyPaths[keyPaths.length-1] = query.key.trim();  
		
		Set<String> keys = TableKit.findKeys(records, keyPaths);
		String sql = query.sql;
		
		if(TableKit.whereExists(query.sql)) {
			sql += " AND ";
		} else {
			sql += " WHERE ";
		} 
		
		List<String> qmarkList = new ArrayList<>();
		for(int i=0;i<keys.size();i++) qmarkList.add("?");
		sql += query.linkColumn + " in (" + String.join(",", qmarkList) + ")"; //TODO single change to =, paging for large
		
		List<Object> params = new ArrayList<>(); 
		if(query.args != null) {
			params.addAll(query.args);
		}
		params.addAll(keys);
		
		//trigger to Database in one time
		List<Map<String, Object>> data = queryList(sql, params.toArray()); 
		
		//group result according to group column
		Map<String, List<Map<String, Object>>> groupedResult = groupResultByKey(data, groupKey); 
		
		//find out all leaf data needs to add new property
		List<Object> expanded = new ArrayList<>();
		TableKit.findByPath(expanded, records, propertyPaths, 0);
		 
		for(Object leaf : expanded) {
			Object keyValue = TableKit.get(leaf, query.key);
			if(keyValue == null) continue; 
			
			List<Map<String, Object>> joinValues = groupedResult.get(keyValue.toString()); 
			Object propertyValue = joinValues;
			if(query.isOne != null && query.isOne) {
				if(joinValues != null) {
					if(joinValues.size() == 1) {
						propertyValue = joinValues.get(0);
					}
				}
			}  
			TableKit.set(leaf, propertyName, propertyValue); 
		} 
	}  

	//==================transactions===============
	@Override
	public void tx(Runnable runner) {
		tx.execute(new TransactionCallbackWithoutResult() {   
			@Override
			protected void doInTransactionWithoutResult(TransactionStatus status) {
				runner.run();
			}
		}); 
	} 
	
	@Route(exclude = true)
	public void setJdbcTemplate(JdbcTemplate jdbc) {
		this.jdbc = jdbc;
	}
	
	@Route(exclude = true)
	public void setTransactionTemplate(TransactionTemplate tx) {
		this.tx = tx;
	}
	
	public static class SqlExecute {
		public String sql;
		public Object[] args;
	}

}
